{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Original code author: https://github.com/wzyonggege/statistical-learning-method\n",
    "\n",
    "Chinese annotations: 机器学习初学者 (WeChat public account, ID: ai-start-com)\n",
    "\n",
    "Environment: Python 3.6\n",
    "\n",
    "All code has been tested and runs successfully.\n",
    "![gongzhong](../gongzhong.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 8: Boosting Methods"
   ]
  },
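  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Boosting\n",
    "\n",
    "Bagging and boosting are the two main approaches to building ensemble models. An ensemble model is made up of several base models, and its predictions are usually better than those of any single base model.\n",
    "\n",
    "- Bagging: each base model is trained on a different dataset obtained by random sampling from the full sample set. Producing these different training sets by resampling is what is called bagging.\n",
    "\n",
    "- Boosting: each base model is trained on the data with different sample weights. Samples misclassified by the previous base model receive larger weights, so that the new model concentrates on them.\n",
    "\n",
    "### AdaBoost\n",
    "\n",
    "AdaBoost is short for Adaptive Boosting, indicating that it is a boosting algorithm that adapts to the errors of the earlier models.\n",
    "\n",
    "The algorithm proceeds as follows:\n",
    "\n",
    "1) Assign a weight to each training sample ($x_1, x_2, \\ldots, x_N$); every entry of the initial weight vector $w_1$ is $1/N$.\n",
    "\n",
    "2) Train on the weighted samples to obtain the model $G_m$ (the first model is $G_1$).\n",
    "\n",
    "3) Compute the weighted error rate of $G_m$: $e_m=\\sum_{i=1}^N w_i I(y_i \\neq G_m(x_i))$, where $w_i$ is the current weight of sample $i$.\n",
    "\n",
    "4) Compute the coefficient of $G_m$: $\\alpha_m=0.5\\log[(1-e_m)/e_m]$ (natural logarithm).\n",
    "\n",
    "5) Update the weight vector $w_{m+1}$ from the error rate $e_m$ (through $\\alpha_m$) and the current weight vector $w_m$.\n",
    "\n",
    "6) Compute the error rate of the ensemble model $f(x)=\\sum_{m=1}^M\\alpha_m G_m(x)$, whose prediction is the sign of $f(x)$.\n",
    "\n",
    "7) Stop when the error rate of the ensemble falls below a threshold or the number of iterations reaches the maximum; otherwise return to step 2)."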
   ]
  },
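  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 5 above leaves the update rule implicit. As a sketch of the standard AdaBoost update for labels $y_i = \\pm 1$ (the symbols $w_{m,i}$ for the $i$-th entry of the weight vector $w_m$ and $Z_m$ for the normalization factor are introduced here for illustration), the new weights can be written as:\n",
    "\n",
    "$$w_{m+1,i} = \\frac{w_{m,i}}{Z_m}\\exp\\bigl(-\\alpha_m y_i G_m(x_i)\\bigr), \\qquad Z_m = \\sum_{i=1}^N w_{m,i}\\exp\\bigl(-\\alpha_m y_i G_m(x_i)\\bigr)$$\n",
    "\n",
    "Since $y_i G_m(x_i)$ is $+1$ for a correct prediction and $-1$ for a wrong one, correctly classified samples are scaled by $e^{-\\alpha_m}/Z_m$ and misclassified samples by $e^{\\alpha_m}/Z_m$. When $e_m < 0.5$ we have $\\alpha_m > 0$, so misclassified samples gain weight, and dividing by $Z_m$ keeps the weights summing to 1. Substituting $\\alpha_m$ from step 4 gives the closed form\n",
    "\n",
    "$$Z_m = (1-e_m)\\,e^{-\\alpha_m} + e_m\\,e^{\\alpha_m} = 2\\sqrt{e_m(1-e_m)}$$\n",
    "\n",
    "These are the quantities computed by the `_Z` and `_w` methods of the implementation further down."
   ]
  },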
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data\n",
    "def create_data():\n",
    "    \"\"\"Return two features of the first 100 iris samples, with labels in {-1, +1}.\"\"\"\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    # keep the first two classes only; features are sepal length and sepal width\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    # relabel class 0 as -1 so that the labels are -1 / +1, as AdaBoost expects\n",
    "    data[data[:, -1] == 0, -1] = -1\n",
    "    # print(data)\n",
    "    return data[:,:2], data[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x1bad866a630>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afrgqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hR
s/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQNh20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQS
opy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9VpX6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xU
zUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQy
pE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "\n",
    "### AdaBoost in Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaBoost:\n",
    "    def __init__(self, n_estimators=50, learning_rate=1.0):\n",
    "        # n_estimators: number of weak classifiers (boosting rounds)\n",
    "        # learning_rate: used in this implementation as the step size of the threshold search\n",
    "        self.clf_num = n_estimators\n",
    "        self.learning_rate = learning_rate\n",
    "    \n",
    "    def init_args(self, datasets, labels):\n",
    "        \n",
    "        self.X = datasets\n",
    "        self.Y = labels\n",
    "        # M: number of samples, N: number of features\n",
    "        self.M, self.N = datasets.shape\n",
    "        \n",
    "        # list of weak classifiers (filled during fit)\n",
    "        self.clf_sets = []\n",
    "        \n",
    "        # initialise the sample weights uniformly\n",
    "        self.weights = [1.0/self.M]*self.M\n",
    "        \n",
    "        # coefficients alpha of the weak classifiers G(x)\n",
    "        self.alpha = []\n",
    "        \n",
    "    def _G(self, features, labels, weights):\n",
    "        # find the best decision stump (threshold and direction) on one feature\n",
    "        m = len(features)\n",
    "        error = 100000.0 # sentinel standing in for infinity\n",
    "        best_v = 0.0\n",
    "        # features is a single feature column\n",
    "        features_min = min(features)\n",
    "        features_max = max(features)\n",
    "        n_step = (features_max - features_min + self.learning_rate) // self.learning_rate\n",
    "        # print('n_step:{}'.format(n_step))\n",
    "        direct, compare_array = None, None\n",
    "        for i in range(1, int(n_step)):\n",
    "            v = features_min + self.learning_rate * i\n",
    "            \n",
    "            if v not in features:\n",
    "                # weighted misclassification error for both stump directions\n",
    "                compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)])\n",
    "                weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]])\n",
    "                \n",
    "                compare_array_negative = np.array([-1 if features[k] > v else 1 for k in range(m)])\n",
    "                weight_error_negative = sum([weights[k] for k in range(m) if compare_array_negative[k] != labels[k]])\n",
    "\n",
    "                if weight_error_positive < weight_error_negative:\n",
    "                    weight_error = weight_error_positive\n",
    "                    _compare_array = compare_array_positive\n",
    "                    _direct = 'positive'\n",
    "                else:\n",
    "                    weight_error = weight_error_negative\n",
    "                    _compare_array = compare_array_negative\n",
    "                    _direct = 'negative'\n",
    "                    \n",
    "                # print('v:{} error:{}'.format(v, weight_error))\n",
    "                if weight_error < error:\n",
    "                    error = weight_error\n",
    "                    compare_array = _compare_array\n",
    "                    best_v = v\n",
    "                    # keep the direction of the best stump, not of the last threshold tried\n",
    "                    direct = _direct\n",
    "        return best_v, direct, error, compare_array\n",
    "        \n",
    "    # compute alpha, the coefficient of a weak classifier\n",
    "    def _alpha(self, error):\n",
    "        # clip the error away from 0 and 1 to avoid division by zero and log(0)\n",
    "        error = min(max(error, 1e-10), 1 - 1e-10)\n",
    "        return 0.5 * np.log((1-error)/error)\n",
    "    \n",
    "    # normalization factor Z\n",
    "    def _Z(self, weights, a, clf):\n",
    "        return sum([weights[i]*np.exp(-1*a*self.Y[i]*clf[i]) for i in range(self.M)])\n",
    "        \n",
    "    # weight update\n",
    "    def _w(self, a, clf, Z):\n",
    "        for i in range(self.M):\n",
    "            self.weights[i] = self.weights[i]*np.exp(-1*a*self.Y[i]*clf[i])/ Z\n",
    "    \n",
    "    # linear combination of the G(x) (placeholder; predict() computes it directly)\n",
    "    def _f(self, alpha, clf_sets):\n",
    "        pass\n",
    "    \n",
    "    def G(self, x, v, direct):\n",
    "        if direct == 'positive':\n",
    "            return 1 if x > v else -1 \n",
    "        else:\n",
    "            return -1 if x > v else 1 \n",
    "    \n",
    "    def fit(self, X, y):\n",
    "        self.init_args(X, y)\n",
    "        \n",
    "        for epoch in range(self.clf_num):\n",
    "            best_clf_error, best_v, clf_result = 100000, None, None\n",
    "            # over all feature dimensions, pick the stump with the smallest weighted error\n",
    "            for j in range(self.N):\n",
    "                features = self.X[:, j]\n",
    "                # threshold, direction, weighted error and predictions of the stump\n",
    "                v, direct, error, compare_array = self._G(features, self.Y, self.weights)\n",
    "                \n",
    "                if error < best_clf_error:\n",
    "                    best_clf_error = error\n",
    "                    best_v = v\n",
    "                    final_direct = direct\n",
    "                    clf_result = compare_array\n",
    "                    axis = j\n",
    "                    \n",
    "                # print('epoch:{}/{} feature:{} error:{} v:{}'.format(epoch, self.clf_num, j, error, best_v))\n",
    "                if best_clf_error == 0:\n",
    "                    break\n",
    "                \n",
    "            # compute the coefficient a of G(x)\n",
    "            a = self._alpha(best_clf_error)\n",
    "            self.alpha.append(a)\n",
    "            # record the weak classifier\n",
    "            self.clf_sets.append((axis, best_v, final_direct))\n",
    "            # normalization factor\n",
    "            Z = self._Z(self.weights, a, clf_result)\n",
    "            # weight update\n",
    "            self._w(a, clf_result, Z)\n",
    "            \n",
    "#             print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, best_clf_error, best_v, final_direct, a))\n",
    "#             print('weight:{}'.format(self.weights))\n",
    "#             print('\\n')\n",
    "    \n",
    "    def predict(self, feature):\n",
    "        result = 0.0\n",
    "        for i in range(len(self.clf_sets)):\n",
    "            axis, clf_v, direct = self.clf_sets[i]\n",
    "            f_input = feature[axis]\n",
    "            result += self.alpha[i] * self.G(f_input, clf_v, direct)\n",
    "        # sign of the weighted vote\n",
    "        return 1 if result > 0 else -1\n",
    "    \n",
    "    def score(self, X_test, y_test):\n",
    "        # fraction of test samples classified correctly\n",
    "        right_count = 0\n",
    "        for i in range(len(X_test)):\n",
    "            feature = X_test[i]\n",
    "            if self.predict(feature) == y_test[i]:\n",
    "                right_count += 1\n",
    "        \n",
    "        return right_count / len(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.arange(10).reshape(10, 1)\n",
    "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = AdaBoost(n_estimators=3, learning_rate=0.5)\n",
    "clf.fit(X, y)"
   ]
  },
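  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check of the first boosting round on this data, assume (as in the textbook's Example 8.1) that the first stump predicts $+1$ for $x < 2.5$ and $-1$ otherwise. The three samples $x = 6, 7, 8$ are then misclassified, each with initial weight $0.1$. The symbols $Z_1$ (normalization factor) and $w_{2,i}$ (updated weight of sample $i$) are notation introduced here for illustration, and $\\log$ is the natural logarithm:\n",
    "\n",
    "$$e_1 = 3 \\times 0.1 = 0.3, \\qquad \\alpha_1 = 0.5\\log\\frac{1-0.3}{0.3} \\approx 0.4236$$\n",
    "\n",
    "$$Z_1 = 7 \\times 0.1\\,e^{-0.4236} + 3 \\times 0.1\\,e^{0.4236} \\approx 0.9165$$\n",
    "\n",
    "$$w_{2,i} = \\frac{0.1\\,e^{-0.4236}}{0.9165} \\approx 0.0714 \\ \\text{(correct)}, \\qquad w_{2,i} = \\frac{0.1\\,e^{0.4236}}{0.9165} \\approx 0.1667 \\ \\text{(misclassified)}$$\n",
    "\n",
    "The updated weights sum to $7 \\times 0.0714 + 3 \\times 0.1667 \\approx 1$, with the misclassified half of the weight mass now sitting on three samples. If the implementation above selects the same first stump, `clf.alpha[0]` should come out close to $0.4236$."
   ]
  },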
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6363636363636364"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average score:65.000%\n"
     ]
    }
   ],
   "source": [
    "# average accuracy over 100 random train/test splits\n",
    "result = []\n",
    "for i in range(1, 101):\n",
    "    X, y = create_data()\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)\n",
    "    clf = AdaBoost(n_estimators=100, learning_rate=0.2)\n",
    "    clf.fit(X_train, y_train)\n",
    "    r = clf.score(X_test, y_test)\n",
    "    # print('{}/100 score: {}'.format(i, r))\n",
    "    result.append(r)\n",
    "\n",
    "# mean accuracy, expressed as a percentage\n",
    "print('average score:{:.3f}%'.format(100 * sum(result) / len(result)))"
   ]
  },
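  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----\n",
    "# sklearn.ensemble.AdaBoostClassifier\n",
    "\n",
    "- algorithm: only AdaBoostClassifier has this parameter, because scikit-learn implements two AdaBoost classification algorithms, SAMME and SAMME.R. They differ mainly in how the weak learner weights are measured. SAMME extends the binary AdaBoost algorithm described above and uses the classification performance on the sample set as the weak learner weight, whereas SAMME.R uses the predicted class probabilities. Because SAMME.R works with continuous probability estimates, it generally needs fewer iterations than SAMME, and it is the default value of `algorithm` in the scikit-learn version used in this notebook. The default is usually sufficient, but note that SAMME.R requires the weak learner `base_estimator` to support probability prediction; SAMME has no such restriction. Newer scikit-learn releases deprecate SAMME.R, so check the documentation of the installed version.\n",
    "\n",
    "- n_estimators: available in both AdaBoostClassifier and AdaBoostRegressor. It is the maximum number of boosting iterations, that is, the maximum number of weak learners. In general, too small a value tends to underfit and too large a value tends to overfit, so a moderate value is chosen. The default is 50. In practice, n_estimators is usually tuned together with learning_rate, described next.\n",
    "\n",
    "- learning_rate: available in both AdaBoostClassifier and AdaBoostRegressor. It is the shrinkage coefficient $\\nu$ applied to the weight of each weak learner. The default is 1.0.\n",
    "\n",
    "- base_estimator: available in both AdaBoostClassifier and AdaBoostRegressor. It is the weak classification or regression learner. In principle any classifier or regressor can be used, provided it supports sample weights. The most common choice is a CART decision tree; the default for AdaBoostClassifier is a decision stump (a tree of depth 1). Newer scikit-learn releases name this parameter `estimator`."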
   ]
  },
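  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `learning_rate` concrete, here is a sketch of the shrinkage idea in the notation of the algorithm steps at the top of this notebook, with $\\nu$ standing for the learning rate and $f_m$ for the ensemble after $m$ rounds (notation introduced here for illustration; scikit-learn's exact scaling of $\\alpha_m$ may differ by constant factors). Each round adds a damped term to the ensemble:\n",
    "\n",
    "$$f_m(x) = f_{m-1}(x) + \\nu\\,\\alpha_m G_m(x), \\qquad \\nu > 0$$\n",
    "\n",
    "With $\\nu = 1$ this reduces to the plain sum $f(x)=\\sum_m \\alpha_m G_m(x)$. With a smaller $\\nu$ each weak learner contributes less, so more rounds (a larger `n_estimators`) are usually needed to reach the same fit on the training data, which is why the two parameters are tuned together."
   ]
  },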
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n",
       "          learning_rate=0.5, n_estimators=100, random_state=None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9393939393939394"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
